{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4945612e",
   "metadata": {},
   "source": [
    "<h1> Encrypted Inference-Linear Regression</h1>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9fdce7bf",
   "metadata": {},
   "source": [
    "Author:\n",
    "<ul>\n",
    "    <li>Hrishikesh Kamath - <a href=\"https://twitter.com/kamathhrishi\">Twitter</a> - <a href=\"https://github.com/kamathhrishi\">Github</a>\n",
    "</ul>\n",
    "Encrypted Inference is the process of performing inference with machine learning models such that model owner cannot observe the true input data nor can the data owner see the true model weights. The weights and data are encrypted by splitting them into shares and performiming computations according to a protocol. The general class of methods know as <b>Secure Multi Party Computation (SMPC)</b>. \n",
    "\n",
    "Below figure depicts MPC for ML models for 2 parties. \n",
    "\n",
    "<img height=\"600px\" width=\"600px\" src=\"Images/smpc_illustration.png\"></img>"
   ]
  },
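  {
   "cell_type": "markdown",
   "id": "sketch-additive-sharing",
   "metadata": {},
   "source": [
    "To make the idea of shares concrete, here is a minimal sketch of additive secret sharing over integers in plain Python. It is not SyMPC's implementation: the ring size `RING` and the helpers `share` and `reconstruct` are names assumed for this illustration.\n",
    "\n",
    "```python\n",
    "import secrets\n",
    "\n",
    "# Illustrative ring size (assumption); real frameworks choose their own modulus\n",
    "RING = 2 ** 64\n",
    "\n",
    "def share(secret, n_parties):\n",
    "    # Draw n-1 random shares, then pick the last one so that all shares sum to the secret\n",
    "    shares = [secrets.randbelow(RING) for _ in range(n_parties - 1)]\n",
    "    last = (secret - sum(shares)) % RING\n",
    "    return shares + [last]\n",
    "\n",
    "def reconstruct(shares):\n",
    "    # Summing every share modulo RING recovers the secret\n",
    "    return sum(shares) % RING\n",
    "\n",
    "x_shares = share(42, 3)\n",
    "y_shares = share(100, 3)\n",
    "\n",
    "# Addition needs no communication: each party adds the two shares it holds\n",
    "sum_shares = [(a + b) % RING for a, b in zip(x_shares, y_shares)]\n",
    "print(reconstruct(sum_shares))  # 142\n",
    "```\n",
    "\n",
    "Any subset of fewer than all the shares is uniformly random, so no single party learns the secret, yet additions can be carried out locally on the shares."
   ]
  },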
  {
   "cell_type": "markdown",
   "id": "943011cc",
   "metadata": {},
   "source": [
    "In this example we use Virtual Machine to demonstrate performing inference using SMPC. That is the workers are present on the same PC. If you want to understand how to perform remotely check out duet tutorials. \n",
    "In this example, we train a Linear regression model in plaintext on Boston Housing Dataset. Then we use the model for performing encrypted inference on test data. This tutorial uses protocol Falcon for 3 parties and SPDZ for 3 and 5 parties. \n",
    "\n",
    "In SyMPC the computation between parties occurs using a orchestrator which describes how computations should take place."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "6cfe9942",
   "metadata": {},
   "outputs": [],
   "source": [
    "#External libraries\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "2b25677e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Import torch\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.utils.data as data_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "67713542",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7fc7e044c2d0>"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Set a manual seed to maintain consistency\n",
    "torch.manual_seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e6b5a2a",
   "metadata": {},
   "source": [
    "<h2>Data Loading and Processing</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "c09bbb21",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2021-09-01 22:03:12--  https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data\n",
      "Resolving archive.ics.uci.edu (archive.ics.uci.edu)... 128.195.10.252\n",
      "Connecting to archive.ics.uci.edu (archive.ics.uci.edu)|128.195.10.252|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 49082 (48K) [application/x-httpd-php]\n",
      "Saving to: ‘housing.data.7’\n",
      "\n",
      "housing.data.7      100%[===================>]  47.93K   108KB/s    in 0.4s    \n",
      "\n",
      "2021-09-01 22:03:14 (108 KB/s) - ‘housing.data.7’ saved [49082/49082]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "#Download Boston housing dataset\n",
    "!wget https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "a38de5c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Import dataset and add headers\n",
    "dataset=pd.read_csv(\"housing.data\",delim_whitespace=True,\n",
    "                    names=[\"crim\",\"zn\",\"indus\",\n",
    "                           \"chas\",\"nox\",\"rm\",\n",
    "                           \"age\",\"dis\",\"rad\",\n",
    "                           \"tax\",\"ptratio\",\"black\",\n",
    "                           \"lstat\",\"medv\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "860112db",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>crim</th>\n",
       "      <th>zn</th>\n",
       "      <th>indus</th>\n",
       "      <th>chas</th>\n",
       "      <th>nox</th>\n",
       "      <th>rm</th>\n",
       "      <th>age</th>\n",
       "      <th>dis</th>\n",
       "      <th>rad</th>\n",
       "      <th>tax</th>\n",
       "      <th>ptratio</th>\n",
       "      <th>black</th>\n",
       "      <th>lstat</th>\n",
       "      <th>medv</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.00632</td>\n",
       "      <td>18.0</td>\n",
       "      <td>2.31</td>\n",
       "      <td>0</td>\n",
       "      <td>0.538</td>\n",
       "      <td>6.575</td>\n",
       "      <td>65.2</td>\n",
       "      <td>4.0900</td>\n",
       "      <td>1</td>\n",
       "      <td>296.0</td>\n",
       "      <td>15.3</td>\n",
       "      <td>396.90</td>\n",
       "      <td>4.98</td>\n",
       "      <td>24.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.02731</td>\n",
       "      <td>0.0</td>\n",
       "      <td>7.07</td>\n",
       "      <td>0</td>\n",
       "      <td>0.469</td>\n",
       "      <td>6.421</td>\n",
       "      <td>78.9</td>\n",
       "      <td>4.9671</td>\n",
       "      <td>2</td>\n",
       "      <td>242.0</td>\n",
       "      <td>17.8</td>\n",
       "      <td>396.90</td>\n",
       "      <td>9.14</td>\n",
       "      <td>21.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.02729</td>\n",
       "      <td>0.0</td>\n",
       "      <td>7.07</td>\n",
       "      <td>0</td>\n",
       "      <td>0.469</td>\n",
       "      <td>7.185</td>\n",
       "      <td>61.1</td>\n",
       "      <td>4.9671</td>\n",
       "      <td>2</td>\n",
       "      <td>242.0</td>\n",
       "      <td>17.8</td>\n",
       "      <td>392.83</td>\n",
       "      <td>4.03</td>\n",
       "      <td>34.7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.03237</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.18</td>\n",
       "      <td>0</td>\n",
       "      <td>0.458</td>\n",
       "      <td>6.998</td>\n",
       "      <td>45.8</td>\n",
       "      <td>6.0622</td>\n",
       "      <td>3</td>\n",
       "      <td>222.0</td>\n",
       "      <td>18.7</td>\n",
       "      <td>394.63</td>\n",
       "      <td>2.94</td>\n",
       "      <td>33.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.06905</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.18</td>\n",
       "      <td>0</td>\n",
       "      <td>0.458</td>\n",
       "      <td>7.147</td>\n",
       "      <td>54.2</td>\n",
       "      <td>6.0622</td>\n",
       "      <td>3</td>\n",
       "      <td>222.0</td>\n",
       "      <td>18.7</td>\n",
       "      <td>396.90</td>\n",
       "      <td>5.33</td>\n",
       "      <td>36.2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      crim    zn  indus  chas    nox     rm   age     dis  rad    tax  \\\n",
       "0  0.00632  18.0   2.31     0  0.538  6.575  65.2  4.0900    1  296.0   \n",
       "1  0.02731   0.0   7.07     0  0.469  6.421  78.9  4.9671    2  242.0   \n",
       "2  0.02729   0.0   7.07     0  0.469  7.185  61.1  4.9671    2  242.0   \n",
       "3  0.03237   0.0   2.18     0  0.458  6.998  45.8  6.0622    3  222.0   \n",
       "4  0.06905   0.0   2.18     0  0.458  7.147  54.2  6.0622    3  222.0   \n",
       "\n",
       "   ptratio   black  lstat  medv  \n",
       "0     15.3  396.90   4.98  24.0  \n",
       "1     17.8  396.90   9.14  21.6  \n",
       "2     17.8  392.83   4.03  34.7  \n",
       "3     18.7  394.63   2.94  33.4  \n",
       "4     18.7  396.90   5.33  36.2  "
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Visualize and look at columns and rows of dataset\n",
    "dataset.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "03e1045a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Split data into features and target variables\n",
    "features = dataset.drop(\"medv\",axis=1)\n",
    "targets = dataset[\"medv\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "2f0ed65a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Normalize features\n",
    "features = features.apply(\n",
    "    lambda x: (x - x.mean()) / x.std()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "b7c38c32",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>crim</th>\n",
       "      <th>zn</th>\n",
       "      <th>indus</th>\n",
       "      <th>chas</th>\n",
       "      <th>nox</th>\n",
       "      <th>rm</th>\n",
       "      <th>age</th>\n",
       "      <th>dis</th>\n",
       "      <th>rad</th>\n",
       "      <th>tax</th>\n",
       "      <th>ptratio</th>\n",
       "      <th>black</th>\n",
       "      <th>lstat</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>-0.419367</td>\n",
       "      <td>0.284548</td>\n",
       "      <td>-1.286636</td>\n",
       "      <td>-0.272329</td>\n",
       "      <td>-0.144075</td>\n",
       "      <td>0.413263</td>\n",
       "      <td>-0.119895</td>\n",
       "      <td>0.140075</td>\n",
       "      <td>-0.981871</td>\n",
       "      <td>-0.665949</td>\n",
       "      <td>-1.457558</td>\n",
       "      <td>0.440616</td>\n",
       "      <td>-1.074499</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>-0.416927</td>\n",
       "      <td>-0.487240</td>\n",
       "      <td>-0.592794</td>\n",
       "      <td>-0.272329</td>\n",
       "      <td>-0.739530</td>\n",
       "      <td>0.194082</td>\n",
       "      <td>0.366803</td>\n",
       "      <td>0.556609</td>\n",
       "      <td>-0.867024</td>\n",
       "      <td>-0.986353</td>\n",
       "      <td>-0.302794</td>\n",
       "      <td>0.440616</td>\n",
       "      <td>-0.491953</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>-0.416929</td>\n",
       "      <td>-0.487240</td>\n",
       "      <td>-0.592794</td>\n",
       "      <td>-0.272329</td>\n",
       "      <td>-0.739530</td>\n",
       "      <td>1.281446</td>\n",
       "      <td>-0.265549</td>\n",
       "      <td>0.556609</td>\n",
       "      <td>-0.867024</td>\n",
       "      <td>-0.986353</td>\n",
       "      <td>-0.302794</td>\n",
       "      <td>0.396035</td>\n",
       "      <td>-1.207532</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>-0.416338</td>\n",
       "      <td>-0.487240</td>\n",
       "      <td>-1.305586</td>\n",
       "      <td>-0.272329</td>\n",
       "      <td>-0.834458</td>\n",
       "      <td>1.015298</td>\n",
       "      <td>-0.809088</td>\n",
       "      <td>1.076671</td>\n",
       "      <td>-0.752178</td>\n",
       "      <td>-1.105022</td>\n",
       "      <td>0.112920</td>\n",
       "      <td>0.415751</td>\n",
       "      <td>-1.360171</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>-0.412074</td>\n",
       "      <td>-0.487240</td>\n",
       "      <td>-1.305586</td>\n",
       "      <td>-0.272329</td>\n",
       "      <td>-0.834458</td>\n",
       "      <td>1.227362</td>\n",
       "      <td>-0.510674</td>\n",
       "      <td>1.076671</td>\n",
       "      <td>-0.752178</td>\n",
       "      <td>-1.105022</td>\n",
       "      <td>0.112920</td>\n",
       "      <td>0.440616</td>\n",
       "      <td>-1.025487</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>501</th>\n",
       "      <td>-0.412820</td>\n",
       "      <td>-0.487240</td>\n",
       "      <td>0.115624</td>\n",
       "      <td>-0.272329</td>\n",
       "      <td>0.157968</td>\n",
       "      <td>0.438881</td>\n",
       "      <td>0.018654</td>\n",
       "      <td>-0.625178</td>\n",
       "      <td>-0.981871</td>\n",
       "      <td>-0.802418</td>\n",
       "      <td>1.175303</td>\n",
       "      <td>0.386834</td>\n",
       "      <td>-0.417734</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>502</th>\n",
       "      <td>-0.414839</td>\n",
       "      <td>-0.487240</td>\n",
       "      <td>0.115624</td>\n",
       "      <td>-0.272329</td>\n",
       "      <td>0.157968</td>\n",
       "      <td>-0.234316</td>\n",
       "      <td>0.288648</td>\n",
       "      <td>-0.715931</td>\n",
       "      <td>-0.981871</td>\n",
       "      <td>-0.802418</td>\n",
       "      <td>1.175303</td>\n",
       "      <td>0.440616</td>\n",
       "      <td>-0.500355</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>503</th>\n",
       "      <td>-0.413038</td>\n",
       "      <td>-0.487240</td>\n",
       "      <td>0.115624</td>\n",
       "      <td>-0.272329</td>\n",
       "      <td>0.157968</td>\n",
       "      <td>0.983986</td>\n",
       "      <td>0.796661</td>\n",
       "      <td>-0.772919</td>\n",
       "      <td>-0.981871</td>\n",
       "      <td>-0.802418</td>\n",
       "      <td>1.175303</td>\n",
       "      <td>0.440616</td>\n",
       "      <td>-0.982076</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>504</th>\n",
       "      <td>-0.407361</td>\n",
       "      <td>-0.487240</td>\n",
       "      <td>0.115624</td>\n",
       "      <td>-0.272329</td>\n",
       "      <td>0.157968</td>\n",
       "      <td>0.724955</td>\n",
       "      <td>0.736268</td>\n",
       "      <td>-0.667776</td>\n",
       "      <td>-0.981871</td>\n",
       "      <td>-0.802418</td>\n",
       "      <td>1.175303</td>\n",
       "      <td>0.402826</td>\n",
       "      <td>-0.864446</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>505</th>\n",
       "      <td>-0.414590</td>\n",
       "      <td>-0.487240</td>\n",
       "      <td>0.115624</td>\n",
       "      <td>-0.272329</td>\n",
       "      <td>0.157968</td>\n",
       "      <td>-0.362408</td>\n",
       "      <td>0.434302</td>\n",
       "      <td>-0.612640</td>\n",
       "      <td>-0.981871</td>\n",
       "      <td>-0.802418</td>\n",
       "      <td>1.175303</td>\n",
       "      <td>0.440616</td>\n",
       "      <td>-0.668397</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>506 rows × 13 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         crim        zn     indus      chas       nox        rm       age  \\\n",
       "0   -0.419367  0.284548 -1.286636 -0.272329 -0.144075  0.413263 -0.119895   \n",
       "1   -0.416927 -0.487240 -0.592794 -0.272329 -0.739530  0.194082  0.366803   \n",
       "2   -0.416929 -0.487240 -0.592794 -0.272329 -0.739530  1.281446 -0.265549   \n",
       "3   -0.416338 -0.487240 -1.305586 -0.272329 -0.834458  1.015298 -0.809088   \n",
       "4   -0.412074 -0.487240 -1.305586 -0.272329 -0.834458  1.227362 -0.510674   \n",
       "..        ...       ...       ...       ...       ...       ...       ...   \n",
       "501 -0.412820 -0.487240  0.115624 -0.272329  0.157968  0.438881  0.018654   \n",
       "502 -0.414839 -0.487240  0.115624 -0.272329  0.157968 -0.234316  0.288648   \n",
       "503 -0.413038 -0.487240  0.115624 -0.272329  0.157968  0.983986  0.796661   \n",
       "504 -0.407361 -0.487240  0.115624 -0.272329  0.157968  0.724955  0.736268   \n",
       "505 -0.414590 -0.487240  0.115624 -0.272329  0.157968 -0.362408  0.434302   \n",
       "\n",
       "          dis       rad       tax   ptratio     black     lstat  \n",
       "0    0.140075 -0.981871 -0.665949 -1.457558  0.440616 -1.074499  \n",
       "1    0.556609 -0.867024 -0.986353 -0.302794  0.440616 -0.491953  \n",
       "2    0.556609 -0.867024 -0.986353 -0.302794  0.396035 -1.207532  \n",
       "3    1.076671 -0.752178 -1.105022  0.112920  0.415751 -1.360171  \n",
       "4    1.076671 -0.752178 -1.105022  0.112920  0.440616 -1.025487  \n",
       "..        ...       ...       ...       ...       ...       ...  \n",
       "501 -0.625178 -0.981871 -0.802418  1.175303  0.386834 -0.417734  \n",
       "502 -0.715931 -0.981871 -0.802418  1.175303  0.440616 -0.500355  \n",
       "503 -0.772919 -0.981871 -0.802418  1.175303  0.440616 -0.982076  \n",
       "504 -0.667776 -0.981871 -0.802418  1.175303  0.402826 -0.864446  \n",
       "505 -0.612640 -0.981871 -0.802418  1.175303  0.440616 -0.668397  \n",
       "\n",
       "[506 rows x 13 columns]"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "1e7111a4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0      24.0\n",
       "1      21.6\n",
       "2      34.7\n",
       "3      33.4\n",
       "4      36.2\n",
       "       ... \n",
       "501    22.4\n",
       "502    20.6\n",
       "503    23.9\n",
       "504    22.0\n",
       "505    11.9\n",
       "Name: medv, Length: 506, dtype: float64"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "targets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "085f5b7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Convert features and targets into torch tensors\n",
    "features = torch.tensor(features.values.astype(np.float32)) \n",
    "targets = torch.tensor(targets.values.astype(np.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "9d775369",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arguments for training\n",
    "batch_size = 16\n",
    "epochs = 300\n",
    "train_test_split = 0.8\n",
    "lr = 0.001"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "83aaf309",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Split dataset into train and test\n",
    "train_indices=int(len(features)*train_test_split)\n",
    "\n",
    "train_x = features[:train_indices]\n",
    "train_y = targets[:train_indices]\n",
    "\n",
    "test_x = features[train_indices+1:]\n",
    "test_y = targets[train_indices+1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "394b765d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Divide dataset into batches\n",
    "def get_batches(X, y):\n",
    "    batches = []\n",
    "    for index in range(0,len(train_x)+1,batch_size):\n",
    "        batches.append((X[index:index+batch_size],y[index:index+batch_size]))\n",
    "    \n",
    "    return batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "dfa2e229",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_batches=get_batches(train_x,train_y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e98f4a3",
   "metadata": {},
   "source": [
    "<h1>Plaintext Training</h1>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "fba7d787",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Import syft\n",
    "import syft as sy\n",
    "sy.logger.remove()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "7f9c72b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Define Linear regression model\n",
    "class LinearSyNet(sy.Module):\n",
    "    def __init__(self, torch_ref):\n",
    "        super(LinearSyNet, self).__init__(torch_ref=torch_ref)\n",
    "        self.fc1 = self.torch_ref.nn.Linear(13,1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "8636bc4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Define model, loss function and optimizer\n",
    "model = LinearSyNet(torch)\n",
    "criterion = torch.nn.MSELoss(reduction='mean') \n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "b4192669",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0/300  Running Loss : 978.0777587890625 and test loss: 367.94952392578125\n",
      "Epoch 50/300  Running Loss : 52.76874542236328 and test loss: 108.14092254638672\n",
      "Epoch 100/300  Running Loss : 43.0278434753418 and test loss: 47.971736907958984\n",
      "Epoch 150/300  Running Loss : 41.08037185668945 and test loss: 30.054576873779297\n",
      "Epoch 200/300  Running Loss : 40.123294830322266 and test loss: 23.0802059173584\n",
      "Epoch 250/300  Running Loss : 39.57941436767578 and test loss: 20.64508628845215\n"
     ]
    }
   ],
   "source": [
    "#Training Loop\n",
    "for epoch in range(epochs):\n",
    "  running_loss = 0.0\n",
    "  for index in range(0,len(train_batches)):\n",
    "    # Clear gradient buffers because we don't want any gradient from previous epoch to carry forward, dont want to cummulate gradients\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    # get output from the model, given the inputs\n",
    "    outputs = model(train_batches[index][0]).reshape([-1])\n",
    "\n",
    "    # get loss for the predicted output\n",
    "    loss = criterion(outputs,train_batches[index][1])\n",
    "    running_loss += loss\n",
    "    # get gradients w.r.t to parameters\n",
    "    loss.backward()\n",
    "\n",
    "    # update parameters\n",
    "    optimizer.step()\n",
    "    \n",
    "  test_accuracy = criterion(model(test_x).reshape([-1]),test_y)\n",
    "  if((epoch%50)==0):\n",
    "     print(f\"Epoch {epoch}/{epochs}  Running Loss : {running_loss.item()/batch_size} and test loss: {test_accuracy.item()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da5db6b3",
   "metadata": {},
   "source": [
    "<h1>Plaintext Inference</h1>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "7d4ff80c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Perform inference in plaintext\n",
    "start_time=time.time()\n",
    "plaintext_predictions = model(test_x).reshape([-1])\n",
    "end_time=time.time()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "e48307d9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MSE Loss:  20.227853775024414\n",
      "Inference time:  0.0007698535919189453 s\n"
     ]
    }
   ],
   "source": [
    "#Calculate inference time and MSELoss\n",
    "print(\"MSE Loss: \",criterion(plaintext_predictions,test_y).item())\n",
    "print(\"Inference time: \",str(end_time-start_time),\"s\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc3a2e0c",
   "metadata": {},
   "source": [
    "<h1>Encrypted Inference</h1>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "e4ba3c9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#SyMPC imports required for encrypted inference\n",
    "import sympc\n",
    "from sympc.session import Session\n",
    "from sympc.session import SessionManager\n",
    "from sympc.tensor import MPCTensor\n",
    "from sympc.protocol import Falcon,FSS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "099f2b61",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_clients(n_parties):\n",
    "  #Generate required number of syft clients and return them.\n",
    "\n",
    "  parties=[]\n",
    "  for index in range(n_parties): \n",
    "      parties.append(sy.VirtualMachine(name = \"worker\"+str(index)).get_root_client())\n",
    "\n",
    "  return parties"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "6cbab6f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_send(data,session):\n",
    "    \"\"\"Splits data into number of chunks equal to number of parties and distributes it to respective \n",
    "       parties.\n",
    "    \"\"\"\n",
    "    data_pointers = []\n",
    "    \n",
    "    split_size = int(len(data)/len(session.parties))+1\n",
    "    for index in range(0,len(session.parties)):\n",
    "        ptr=data[index*split_size:index*split_size+split_size].share(session=session)\n",
    "        data_pointers.append(ptr)\n",
    "        \n",
    "    return data_pointers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "632d366b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def inference(n_clients,protocol=None):\n",
    "    \n",
    "  # Get VM clients \n",
    "  parties=get_clients(n_clients)\n",
    "\n",
    "  # Setup the session for the computation\n",
    "  if(protocol):\n",
    "     session = Session(parties = parties,protocol = protocol)\n",
    "  else:\n",
    "     session = Session(parties = parties)\n",
    "        \n",
    "  SessionManager.setup_mpc(session)\n",
    "\n",
    "  #Split data and send data to clients\n",
    "  pointers = split_send(test_x,session)\n",
    "\n",
    "  #Encrypt model \n",
    "  mpc_model = model.share(session)\n",
    "\n",
    "  #Encrypt test data\n",
    "  #test_data=MPCTensor(secret=test_x, session = session)\n",
    "\n",
    "  #Perform inference and measure time taken\n",
    "  start_time = time.time()\n",
    "    \n",
    "  results = []\n",
    "    \n",
    "  for ptr in pointers:\n",
    "     encrypted_results = mpc_model(ptr)\n",
    "     plaintext_results = encrypted_results.reconstruct()\n",
    "     results.append(plaintext_results)\n",
    "        \n",
    "  end_time = time.time()\n",
    "\n",
    "  print(f\"Time for inference: {end_time-start_time}s\")\n",
    "    \n",
    "  predictions = torch.cat(results).reshape([-1])\n",
    "\n",
    "  #Calculate Loss\n",
    "  print(\"MSE Loss: \",criterion(predictions,test_y).item())\n",
    "    \n",
    "  return predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "6b35ec4f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time for inference: 0.11805987358093262s\n",
      "MSE Loss:  20.227657318115234\n"
     ]
    }
   ],
   "source": [
    "predictions=inference(3,Falcon(\"semi-honest\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec10fe5a",
   "metadata": {},
   "source": [
    "We can see that the prediction values and mean squared error values are almost the same as final model. Small differences are due to precision loss."
   ]
  },
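  {
   "cell_type": "markdown",
   "id": "sketch-fixed-point",
   "metadata": {},
   "source": [
    "The kind of precision loss mentioned above can be reproduced with a small fixed-point sketch in plain Python. It assumes 16 fractional bits; that value, along with the helper names `encode`, `decode` and `fixed_mul`, is an assumption for this illustration and not a statement about SyMPC's configuration.\n",
    "\n",
    "```python\n",
    "# Assumed number of fractional bits, chosen for illustration only\n",
    "FRACTIONAL_BITS = 16\n",
    "SCALE = 2 ** FRACTIONAL_BITS\n",
    "\n",
    "def encode(value):\n",
    "    # Map a float to the nearest integer multiple of 1/SCALE\n",
    "    return round(value * SCALE)\n",
    "\n",
    "def decode(encoded):\n",
    "    # Map a fixed-point integer back to a float\n",
    "    return encoded / SCALE\n",
    "\n",
    "def fixed_mul(a, b):\n",
    "    # The raw product carries a factor of SCALE**2, so one factor is divided out;\n",
    "    # floor division discards the low-order bits\n",
    "    return (a * b) // SCALE\n",
    "\n",
    "weight, feature = 0.1234567, 1.5\n",
    "exact = weight * feature\n",
    "approx = decode(fixed_mul(encode(weight), encode(feature)))\n",
    "error = abs(exact - approx)\n",
    "print(exact, approx, error)\n",
    "```\n",
    "\n",
    "The error stays within a few multiples of `1/SCALE`, which is the same order of magnitude as the gap between the encrypted and plaintext predictions."
   ]
  },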
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "c1bea240",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Index 0\n",
      "Encrypted Prediction Output 2.636962890625\n",
      "Plaintext Prediction Output 2.6367740631103516\n",
      "Expected Prediction: 5.0\n",
      "\n",
      "\n",
      "Index 1\n",
      "Encrypted Prediction Output 4.957275390625\n",
      "Plaintext Prediction Output 4.957102298736572\n",
      "Expected Prediction: 11.899999618530273\n",
      "\n",
      "\n",
      "Index 2\n",
      "Encrypted Prediction Output 19.214553833007812\n",
      "Plaintext Prediction Output 19.214506149291992\n",
      "Expected Prediction: 27.899999618530273\n",
      "\n",
      "\n",
      "Index 3\n",
      "Encrypted Prediction Output 12.028610229492188\n",
      "Plaintext Prediction Output 12.028535842895508\n",
      "Expected Prediction: 17.200000762939453\n",
      "\n",
      "\n",
      "Index 4\n",
      "Encrypted Prediction Output 19.063995361328125\n",
      "Plaintext Prediction Output 19.063936233520508\n",
      "Expected Prediction: 27.5\n",
      "\n",
      "\n",
      "Index 5\n",
      "Encrypted Prediction Output 11.89825439453125\n",
      "Plaintext Prediction Output 11.898107528686523\n",
      "Expected Prediction: 15.0\n",
      "\n",
      "\n",
      "Index 6\n",
      "Encrypted Prediction Output 16.335174560546875\n",
      "Plaintext Prediction Output 16.335037231445312\n",
      "Expected Prediction: 17.200000762939453\n",
      "\n",
      "\n",
      "Index 7\n",
      "Encrypted Prediction Output -1.2559814453125\n",
      "Plaintext Prediction Output -1.2561841011047363\n",
      "Expected Prediction: 17.899999618530273\n",
      "\n",
      "\n",
      "Index 8\n",
      "Encrypted Prediction Output 8.844085693359375\n",
      "Plaintext Prediction Output 8.843910217285156\n",
      "Expected Prediction: 16.299999237060547\n",
      "\n",
      "\n",
      "Index 9\n",
      "Encrypted Prediction Output -8.952850341796875\n",
      "Plaintext Prediction Output -8.953043937683105\n",
      "Expected Prediction: 7.0\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for index in range(0,10):\n",
    "    print(f\"Index {index}\")\n",
    "    print(f\"Encrypted Prediction Output {predictions[index].item()}\")\n",
    "    print(f\"Plaintext Prediction Output {plaintext_predictions[index].item()}\")\n",
    "    print(f\"Expected Prediction: {test_y[index]}\")\n",
    "    print(\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a3de12f",
   "metadata": {},
   "source": [
    "<h1> Conclusion </h1>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "017c5bb7",
   "metadata": {},
   "source": [
    "Falcon can also provide a malicious security guarantee for an honest majority at the cost of higher inference time. Malicious security ensures that all the parties compute according to the protocol and do not deviate from protocol or tamper with shares. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "0671bd87",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time for inference: 0.665564775466919s\n",
      "MSE Loss:  20.227645874023438\n"
     ]
    }
   ],
   "source": [
    "predictions=inference(3,Falcon(\"malicious\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e543e93",
   "metadata": {},
   "source": [
    "When we do not pass any protocol to session, SyMPC uses SPDZ and FSS protocol with semi-honest security type. \n",
    "\n",
    "SPDZ is used for multiplication and related operations (convolution,matmul,etc).\n",
    "Functional Secret Sharing (FSS) for other operations such as comparison, equality, maxpool, etc. \n",
    "\n",
    "FSS works only for 2 parties while SPDZ could extend to N parties. \n",
    "\n",
    "Linear regression uses only matmul which utilizes SPDZ protocol allowing us to run linear regression with several parties in this tutorial. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "038bbe3c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time for inference: 0.4993562698364258s\n",
      "MSE Loss:  20.227697372436523\n"
     ]
    }
   ],
   "source": [
    "predictions=inference(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "a8a8083d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time for inference: 1.31923508644104s\n",
      "MSE Loss:  20.22768211364746\n"
     ]
    }
   ],
   "source": [
    "predictions=inference(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8325a69",
   "metadata": {},
   "source": [
    "<center><h3> Comparison </h3></center>\n",
    "\n",
    "| Protocol | Security Type| Parties | Inference Time (s) |\n",
    "| --- | --- | --- | --- |\n",
    "| Plaintext | |  | 0.000534|\n",
    "| Falcon | Semi-honest | 3 | 0.118 |\n",
    "| Falcon | Malicious | 3 | 0.6659 |\n",
    "| SPDZ| Semi-honest | 3 | 0.4993 |\n",
    "| SPDZ | Semi-honest | 5 | 1.3192|\n",
    "\n",
    "<b>Note:</b> The above table is only for comparison. The inference time varies for different PC specs and CPU load.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f092cecc",
   "metadata": {},
   "source": [
    "Falcon provides faster inference for a 3 parties setting. While, SPDZ allows inference for N number of parties. Both allow inference with almost same accuracy as plaintext. \n",
    "\n",
    "Although Falcon is much more faster, it isn't scalable because it is applicable for only 3 parties. Further, it is less secure. Since it uses 2-out-of-3 sharing where each party recieves two shares allowing 2 parties to reconstruct a secret without other party knowing. Falcon also assumes that majority of parties are honest (2 in this case). \n",
    "\n",
    "While, SPDZ and FSS distributes a single share to every party requiring shares from all parties for reconstruction ensuring no parties could collude. Currently SyMPC provides support for SPDZ and FSS with semi-honest security guarantee. This allows parties to tamper with shares leading to incorrect results. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dbef4aaf",
   "metadata": {},
   "source": [
    "<h3>What's next?</h3>\n",
    "\n",
    "SyMPC is still under development! We will add here more features as soon they are stable enough, stay tuned! 🕺\n",
    "\n",
    "If you enjoyed this tutorial, show your support by Starring SyMPC! 🙏"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
